{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TRAIN HANGUL-RNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Packages Imported\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "# Import Packages\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import collections\n",
    "import argparse\n",
    "import time\n",
    "import os\n",
    "from six.moves import cPickle\n",
    "from TextLoader import *\n",
    "from Hangulpy import *\n",
    "print (\"Packages Imported\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LOAD DATASET WITH TEXTLOADER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading preprocessed files\n"
     ]
    }
   ],
   "source": [
    "data_dir    = \"data/nine_dreams\"\n",
    "batch_size  = 50\n",
    "seq_length  = 50\n",
    "data_loader = TextLoader(data_dir, batch_size, seq_length)\n",
    "# This makes \"vocab.pkl\" and \"data.npy\" in \"data/nine_dreams\"   \n",
    "#  from \"data/nine_dreams/input.txt\" "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# VOCAB AND CHARS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "type of 'data_loader.vocab' is <type 'dict'>, length is 76\n",
      "type of 'data_loader.chars' is <type 'tuple'>, length is 76\n"
     ]
    }
   ],
   "source": [
    "vocab_size = data_loader.vocab_size\n",
    "vocab = data_loader.vocab\n",
    "chars = data_loader.chars\n",
    "print ( \"type of 'data_loader.vocab' is %s, length is %d\" \n",
    "       % (type(data_loader.vocab), len(data_loader.vocab)) )\n",
    "print ( \"type of 'data_loader.chars' is %s, length is %d\" \n",
    "       % (type(data_loader.chars), len(data_loader.chars)) )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# VOCAB: DICTIONARY (CHAR->INDEX)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{u'_': 69, u'6': 59, u':': 57, u'\\n': 19, u'4': 67, u'5': 63, u'>': 75, u'!': 52, u' ': 1, u'\"': 28, u'\\u1d25': 0, u\"'\": 49, u')': 46, u'(': 45, u'-': 65, u',': 27, u'.': 24, u'\\u3131': 7, u'0': 73, u'\\u3133': 60, u'\\u3132': 29, u'\\u3135': 50, u'\\u3134': 4, u'\\u3137': 13, u'\\u3136': 44, u'\\u3139': 5, u'\\u3138': 32, u'\\u313b': 55, u'\\u313a': 48, u'\\u313c': 54, u'?': 41, u'3': 66, u'\\u3141': 12, u'\\u3140': 51, u'\\u3143': 47, u'\\u3142': 17, u'\\u3145': 10, u'\\u3144': 43, u'\\u3147': 2, u'\\u3146': 22, u'\\u3149': 40, u'\\u3148': 15, u'\\u314b': 42, u'\\u314a': 23, u'\\u314d': 31, u'\\u314c': 30, u'\\u314f': 3, u'\\u314e': 14, u'\\u3151': 34, u'\\u3150': 21, u'\\u3153': 11, u'\\u3152': 74, u'\\u3155': 18, u'\\u3154': 20, u'\\u3157': 9, u'\\u3156': 39, u'\\u3159': 53, u'\\u3158': 26, u'\\u315b': 38, u'\\u315a': 33, u'\\u315d': 36, u'\\u315c': 16, u'\\u315f': 35, u'\\u315e': 61, u'\\u3161': 8, u'\\u3160': 37, u'\\u3163': 6, u'\\u3162': 25, u'\\x1a': 72, u'9': 64, u'7': 71, u'2': 62, u'1': 58, u'\\u313f': 56, u'\\u313e': 70, u'8': 68}\n"
     ]
    }
   ],
   "source": [
    "print (data_loader.vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CHARS: LIST (INDEX->CHAR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(u'\\u1d25', u' ', u'\\u3147', u'\\u314f', u'\\u3134', u'\\u3139', u'\\u3163', u'\\u3131', u'\\u3161', u'\\u3157', u'\\u3145', u'\\u3153', u'\\u3141', u'\\u3137', u'\\u314e', u'\\u3148', u'\\u315c', u'\\u3142', u'\\u3155', u'\\n', u'\\u3154', u'\\u3150', u'\\u3146', u'\\u314a', u'.', u'\\u3162', u'\\u3158', u',', u'\"', u'\\u3132', u'\\u314c', u'\\u314d', u'\\u3138', u'\\u315a', u'\\u3151', u'\\u315f', u'\\u315d', u'\\u3160', u'\\u315b', u'\\u3156', u'\\u3149', u'?', u'\\u314b', u'\\u3144', u'\\u3136', u'(', u')', u'\\u3143', u'\\u313a', u\"'\", u'\\u3135', u'\\u3140', u'!', u'\\u3159', u'\\u313c', u'\\u313b', u'\\u313f', u':', u'1', u'6', u'\\u3133', u'\\u315e', u'2', u'5', u'9', u'-', u'3', u'4', u'8', u'_', u'\\u313e', u'7', u'\\x1a', u'0', u'\\u3152', u'>')\n",
      "ᴥ\n"
     ]
    }
   ],
   "source": [
    "print (data_loader.chars)\n",
    "# USAGE\n",
    "print (data_loader.chars[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TRAINING BATCH (IMPORTANT!!)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Type of 'x' is <type 'numpy.ndarray'>. Shape is (50, 50)\n",
      "x looks like \n",
      "[[ 3  5  0 ...,  3  4  0]\n",
      " [20  0  1 ..., 13  3  0]\n",
      " [10 11  2 ...,  1  7  3]\n",
      " ..., \n",
      " [ 1 17  6 ...,  0  1  7]\n",
      " [ 0 14  3 ..., 12  3  4]\n",
      " [ 0  7  3 ...,  1 15  3]]\n",
      "\n",
      "Type of 'y' is <type 'numpy.ndarray'>. Shape is (50, 50)\n",
      "y looks like \n",
      "[[ 5  0  1 ...,  4  0 15]\n",
      " [ 0  1  7 ...,  3  0 24]\n",
      " [11  2  0 ...,  7  3  0]\n",
      " ..., \n",
      " [17  6  0 ...,  1  7  9]\n",
      " [14  3  0 ...,  3  4  0]\n",
      " [ 7  3  0 ..., 15  3  2]]\n"
     ]
    }
   ],
   "source": [
    "x, y = data_loader.next_batch()\n",
    "print (\"Type of 'x' is %s. Shape is %s\" % (type(x), x.shape,))\n",
    "print (\"x looks like \\n%s\" % (x))\n",
    "print\n",
    "print (\"Type of 'y' is %s. Shape is %s\" % (type(y), y.shape,))\n",
    "print (\"y looks like \\n%s\" % (y))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DEFINE A MULTILAYER LSTM NETWORK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network ready\n"
     ]
    }
   ],
   "source": [
    "rnn_size   = 512\n",
    "num_layers = 3\n",
    "grad_clip  = 5. # <= GRADIENT CLIPPING (PRACTICALLY IMPORTANT)\n",
    "vocab_size = data_loader.vocab_size\n",
    "\n",
    "# SELECT RNN CELL (MULTI LAYER LSTM)\n",
    "unitcell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)\n",
    "cell = tf.nn.rnn_cell.MultiRNNCell([unitcell] * num_layers)\n",
    "\n",
    "# Set paths to the graph\n",
    "input_data = tf.placeholder(tf.int32, [batch_size, seq_length])\n",
    "targets    = tf.placeholder(tf.int32, [batch_size, seq_length])\n",
    "initial_state = cell.zero_state(batch_size, tf.float32)\n",
    "\n",
    "# Set Network\n",
    "with tf.variable_scope('rnnlm'):\n",
    "    softmax_w = tf.get_variable(\"softmax_w\", [rnn_size, vocab_size])\n",
    "    softmax_b = tf.get_variable(\"softmax_b\", [vocab_size])\n",
    "    with tf.device(\"/cpu:0\"):\n",
    "        embedding = tf.get_variable(\"embedding\", [vocab_size, rnn_size])\n",
    "        inputs = tf.split(1, seq_length, tf.nn.embedding_lookup(\n",
    "                embedding, input_data))\n",
    "        inputs = [tf.squeeze(input_, [1]) for input_ in inputs]\n",
    "print (\"Network ready\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FUNCTIONS READY\n"
     ]
    }
   ],
   "source": [
    "# Output of RNN\n",
    "outputs, last_state = tf.nn.seq2seq.rnn_decoder(inputs, initial_state\n",
    "                        , cell, loop_function=None, scope='rnnlm')\n",
    "output = tf.reshape(tf.concat(1, outputs), [-1, rnn_size])\n",
    "logits = tf.nn.xw_plus_b(output, softmax_w, softmax_b)\n",
    "\n",
    "# Next word probability\n",
    "probs = tf.nn.softmax(logits)\n",
    "print (\"FUNCTIONS READY\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DEFINE LOSS FUNCTION "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LOSS FUNCTION\n"
     ]
    }
   ],
   "source": [
    "loss = tf.nn.seq2seq.sequence_loss_by_example([logits], # Input\n",
    "    [tf.reshape(targets, [-1])], # Target\n",
    "    [tf.ones([batch_size * seq_length])], # Weight\n",
    "    vocab_size)\n",
    "print (\"LOSS FUNCTION\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DEFINE COST FUNCTION "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NETWORK READY\n"
     ]
    }
   ],
   "source": [
    "cost = tf.reduce_sum(loss) / batch_size / seq_length\n",
    "\n",
    "# GRADIENT CLIPPING ! \n",
    "lr = tf.Variable(0.0, trainable=False) # <= LEARNING RATE \n",
    "tvars = tf.trainable_variables()\n",
    "grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), grad_clip)\n",
    "_optm = tf.train.AdamOptimizer(lr)\n",
    "optm = _optm.apply_gradients(zip(grads, tvars))\n",
    "\n",
    "final_state = last_state\n",
    "print (\"NETWORK READY\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OPTIMIZE NETWORK WITH LR SCHEDULING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "num_epochs    = 500\n",
    "save_every    = 1000\n",
    "learning_rate = 0.0002\n",
    "decay_rate    = 0.97\n",
    "\n",
    "save_dir = 'data/nine_dreams'\n",
    "sess = tf.Session()\n",
    "sess.run(tf.initialize_all_variables())\n",
    "summary_writer = tf.train.SummaryWriter(save_dir\n",
    "                    , graph=sess.graph)\n",
    "saver = tf.train.Saver(tf.all_variables())\n",
    "for e in range(num_epochs): # for all epochs\n",
    "\n",
    "    # LEARNING RATE SCHEDULING \n",
    "    sess.run(tf.assign(lr, learning_rate * (decay_rate ** e)))\n",
    "\n",
    "    data_loader.reset_batch_pointer()\n",
    "    state = sess.run(initial_state)\n",
    "    for b in range(data_loader.num_batches):\n",
    "        start = time.time()\n",
    "        x, y = data_loader.next_batch()\n",
    "        feed = {input_data: x, targets: y, initial_state: state}\n",
    "        # Train!\n",
    "        train_loss, state, _ = sess.run([cost, final_state, optm], feed)\n",
    "        end = time.time()\n",
    "        # PRINT \n",
    "        if b % 100 == 0:\n",
    "            print (\"%d/%d (epoch: %d), loss: %.3f, time/batch: %.3f\"  \n",
    "                   % (e * data_loader.num_batches + b\n",
    "                    , num_epochs * data_loader.num_batches\n",
    "                    , e, train_loss, end - start))\n",
    "        # SAVE MODEL\n",
    "        if (e * data_loader.num_batches + b) % save_every == 0:\n",
    "            checkpoint_path = os.path.join(save_dir, 'model.ckpt')\n",
    "            saver.save(sess, checkpoint_path\n",
    "                       , global_step = e * data_loader.num_batches + b)\n",
    "            print(\"model saved to {}\".format(checkpoint_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# IT TAKE A LOOOOOOOOT OF TIME"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
